####################################################
#Author: Kelli Marquardt
#Purpose: Train and run the ML algorithm to predict Qij for all visits; produce the word clouds and the validation figure (fig 2 and fig c1)

# Inputs:
#- data/note_dat_fake_gptnotes.csv and note_dat_fake_base.csv   

# Outputs:
#- data/intermediate/: 
    #cleaned_appt_notes_fake.csv and note_dat_fake_wQ.csv
#- output/logs/step1_predictQ_log.txt 
#- output/figures/ fig_c1.png, fig_2a.png, fig_2b.png 

####################################################




############################
#0 load required packages
############################
rm(list = ls(all.names = TRUE))

#load packages 
library(dplyr) 
library(ggplot2)
library(textclean) #for fixing contractions, add_comma_space
library(data.table) #fread and fwrite
library(tidytext) #text mining (e.g. tokenizing) and stop_words
library(SnowballC) #for stemming
library(tidyr) #for pivot wider
library(mlr) #for the machine learning analysis
library(randomForest)
library(parallel) #for running code in parallel
library(parallelMap) #for running code in parallel
library(pbapply) #for apply function progress bar 
library(stringr) #for str_squish
library(hunspell) #for spell check and replacement
library(ggwordcloud) #for word cloud visual


############################
#0 Make sure medical dictionary is downloaded and in use (see instructions in README and Step0_InstallPackages.R)
############################
#check that medical dictionary was downloaded successfully:
#with the medical dictionary on DICPATH, hunspell's first suggestion for the
#misspelling "Olanzapinee" must be "Olanzapine".
#isTRUE() makes the check fail cleanly (via stop) even when hunspell returns
#no suggestion at all: indexing an empty vector gives NA, and `if (NA)` would
#abort with "missing value where TRUE/FALSE needed" instead of this message.
Sys.setenv(DICPATH = file.path("..", "data", "hunspell-en-med-glut-workaround-master"))
dictionary("en_US")
med_dict_check <- hunspell_suggest("Olanzapinee")[[1]]
if (!isTRUE(med_dict_check[1] == "Olanzapine")) {
  stop("Make sure medical dictionary is saved in DICPATH")
}
rm(med_dict_check)


############################
#0 Define functions used in estimation and simulation 
############################

#Normalize free-text notes before tokenizing: expand contractions, make sure
#every comma is followed by a space, lower-case everything, and collapse runs
#of whitespace (also trimming both ends). Works element-wise on a character
#vector, so it can be applied to a whole column at once.
#note- an extra gsub() step was specific to cleaning the true de-identified
#data, so it is commented out in the replication code
clean_notes_specific <- function(x) {
  expanded <- textclean::replace_contraction(x)
  comma_spaced <- textclean::add_comma_space(expanded)
  lowered <- tolower(comma_spaced)
  #lowered <- gsub("[^a-zA-Z0-9 ]|[0-9]|xx+", "  ", lowered)
  stringr::str_squish(lowered)
}

# Spelling fix for a single word: hunspell's first suggestion, or the word
# itself when hunspell offers no suggestion
get_first_suggestion <- function(word) {
  candidates <- hunspell::hunspell_suggest(word)[[1]]
  if (length(candidates) == 0) {
    return(word)
  }
  candidates[1]
}


#################################################
#Step 1: read in the notes data and process 
  #clean, spell-check, re-clean, and save in data/intermediate to avoid double process time. 
#################################################


###Loading in the data (both the gpt synthetic notes and the other variables)
data_dir <- file.path("..", "data")
gpt_note_dat <- read.csv(file.path(data_dir, "note_dat_fake_gptnotes.csv"), stringsAsFactors = FALSE)
other_note_dat <- read.csv(file.path(data_dir, "note_dat_fake_base.csv"), stringsAsFactors = FALSE)

#merge in the gpt generated notes and define a single dataset:
#only the note text is taken from the gpt file; the right join keeps every
#row of the base file, even those without a gpt note
gpt_notes_only <- select(gpt_note_dat, pat_id, visit_id, note_text)
note_dat <- right_join(gpt_notes_only, other_note_dat, by = c("pat_id", "visit_id"))
note_dat <- rename(note_dat, notes = note_text)
rm(gpt_note_dat, other_note_dat, gpt_notes_only, data_dir)

###cleaning the data 

#run clean_notes_specific over every note
#Note: for true data, this was done in parallel to reduce time (see commented block below)
clean_dta <- note_dat
clean_dta$notes <- clean_notes_specific(clean_dta$notes)

# cl = makeCluster(10) #create cluster cl (10 cores)
# clean_dta = clean_dta %>%
#   mutate(notes = pbsapply(notes, clean_notes_specific, cl = cl))
# stopCluster(cl)  # stop cluster
# gc()

#####
#applying spell check to unique words:
#each distinct token is checked once, the correction is mapped back onto every
#occurrence, and the notes are re-assembled (one row per pat_id, visit_id)

# Tokenize notes into one row per word (lower-cased, punctuation stripped)
clean_dta_tokens <- clean_dta %>%
  unnest_tokens(word, notes, to_lower = TRUE, drop = TRUE, strip_punct = TRUE)

# Distinct tokens that fail the spell check. hunspell_check cares about
# capitalization, so the check (and the suggestion lookup below) runs on an
# upper-cased copy, while `word` keeps the original lower-case token as the
# join key. tolower(toupper(x)) is not guaranteed to give back x for
# non-ASCII letters, so rebuilding the key that way could silently make the
# join below miss (or duplicate) corrections.
spell_check <- clean_dta_tokens %>%
  distinct(word) %>%
  mutate(word_upper = toupper(word)) %>%
  filter(!hunspell_check(word_upper))

# First suggestion per misspelled word (the word itself when there is none)
#Note: for true data, this was done in parallel to reduce time (see commented block below)
spell_check$suggest <- vapply(spell_check$word_upper, get_first_suggestion,
                              character(1), USE.NAMES = FALSE)

# cl = makeCluster(10) #create cluster cl (10 cores)
# spell_check = spell_check %>%
#   mutate(suggest = pbsapply(word_upper, get_first_suggestion, cl=cl))
# stopCluster(cl)  # stop cluster
# gc()

# Normalize suggestions: keep only letters, digits and spaces, then lower-case.
# A suggestion that normalizes to "" would delete the token from the note, so
# it is treated as "no correction" (NA) and the original token is kept.
spell_check <- spell_check %>%
  mutate(suggest = tolower(gsub("[^a-zA-Z0-9 ]", "", suggest)),
         suggest = na_if(suggest, "")) %>%
  select(word, suggest)

# Replace each misspelled token by its suggestion; correctly spelled tokens
# get suggest = NA from the left join and are kept as-is
clean_dta_tokens <- clean_dta_tokens %>%
  left_join(spell_check, by = "word") %>%
  mutate(word = coalesce(suggest, word))

#put back into un-tokenized form (one note per pat_id, visit_id)
clean_dta <- clean_dta_tokens %>%
  group_by(pat_id, visit_id) %>%
  summarize(notes = str_c(word, collapse = " "), .groups = "drop")


### clean one more time
#Note: for true data, this was done in parallel to reduce time (see commented block below)
clean_dta$notes <- clean_notes_specific(clean_dta$notes)

# cl = makeCluster(10) #create cluster cl (10 cores)
# clean_dta = clean_dta %>%
#   mutate(notes = pbsapply(notes, clean_notes_specific, cl = cl))
# stopCluster(cl)  # stop cluster
# gc() 

#read back in original vars needed for rest of code (for now, just need assigned_label).
#Starting from note_dat keeps every visit in note_dat's row order; visits that
#produced no tokens get notes = NA.
clean_dta <- left_join(
  select(note_dat, pat_id, visit_id, assigned_label),
  clean_dta,
  by = c("pat_id", "visit_id")
)


### saving as intermediate output for use in step2 code 
fwrite(clean_dta,
       file.path("..", "data", "intermediate", "cleaned_appt_notes_fake.csv"),
       row.names = FALSE)

#clean up the spell-check working objects
rm(clean_dta_tokens, spell_check)
###############################################


#################################################
#Step 2: test, train, and run the machine learning algorithm 
#################################################
##
#open log file: printed output up to the matching sink() at the end of
#Step 2 is written to output/logs/step1_predictQ_log.txt
sink(file.path("..", "output", "logs", "step1_predictQ_log.txt"))

print(Sys.Date())


#set seed and user-options 
seed <- 123456
set.seed(seed)

#number of most frequent words to start with when determining features.
#the goal is 40 features, but many words appear in both groups, so start bigger
number_of_words <- 150

grid_search_n <- 100 #how many hyper-parameter sets to try when tuning (random search)
min_measure <- "fdr" #target to minimize- false discovery rate (tuneParams below passes measures=fdr directly)

###############
#Prepare data 

  ###tokenize, remove stop_words and stem every token
  clean_data_token <- clean_dta %>%
    unnest_tokens(word, notes) %>%
    anti_join(stop_words, by = "word") %>%
    mutate(stem = wordStem(word, language = "en"))

  #word -> stem mapping (one row per token), kept for the word cloud
  stem_to_word_dat <- select(clean_data_token, word, stem)

  #replace each word with its stem, then give every (pat_id, visit_id)
  #pair its own integer id
  clean_data_token <- clean_data_token %>%
    mutate(word = stem) %>%
    select(-stem) %>%
    group_by(pat_id, visit_id) %>%
    mutate(pat_visit_id = cur_group_id()) %>%
    ungroup()

###############
#Get the labeled set and determine features 
#Feature words: start from the number_of_words highest-share (label, word)
#pairs, drop words that appear for both labels, then keep up to 20 of the
#remaining words per label.
  
  #tokens of visits whose assigned_label is not missing
  labeled_set=clean_data_token[which(!is.na(clean_data_token$assigned_label)),]

  ###Determine top words
  #word_pct = occurrences of the word among the tokens of one label, as a
  #percentage of all tokens of that label (q1_word_count holds this per-label
  #token total for both labels, despite its name). After ungroup(), the
  #number_of_words rows with the highest word_pct are kept across both labels.
  top_words = labeled_set %>%
    group_by(assigned_label) %>%
    mutate(q1_word_count = n()) %>%
    group_by(assigned_label, word) %>%
    mutate(word_count = n(),
           word_pct = word_count / q1_word_count * 100) %>%
    select(word, assigned_label, q1_word_count, word_count, word_pct) %>%
    distinct() %>%
    ungroup()%>%
    arrange(desc(word_pct)) %>%
    slice_head(n = number_of_words)%>%
    select(assigned_label, word, word_pct)
  
  #remove words that are in more than one group
  #(a word kept for both labels has two rows, i.e. both == 2)
  top_words_2 = top_words %>%
    ungroup() %>%
    group_by(word) %>%
    mutate(both = n()) %>%
    filter(both < 2) %>%
    select(assigned_label, word, word_pct)
  #pick the top 20 words distinct for each group 
  #(slice_head works per assigned_label because of the group_by)
  top_words_2_all=top_words_2 %>%
    group_by(assigned_label) %>%
    arrange(desc(word_pct)) %>%
    slice_head(n = 20)%>%
    select(assigned_label,top_word=word)
  
  #create lists of the top words per group
  #each is a list with elements assigned_label and top_word, both converted to
  #character vectors; the word cloud step reads q1_words$top_word / q0_words$top_word
  q1_words = lapply(top_words_2_all[top_words_2_all$assigned_label ==1,], as.character)
  q0_words = lapply(top_words_2_all[top_words_2_all$assigned_label == 0,], as.character)


  ###creating features for each note
  #features: length of the note + relative frequency of each selected top word

  #length of note: number of (stemmed, non-stop-word) tokens per visit
  note_length <- labeled_set %>%
    count(pat_visit_id, name = "word_count")

  ###build the dfm
  #document-feature matrix restricted to the top words: one row per visit,
  #one column per top word holding its count (NA when the word is absent)
  train_feat <- labeled_set[labeled_set$word %in% top_words_2_all$top_word, ]
  dtm_temp <- train_feat %>%
    group_by(pat_visit_id) %>%
    count(pat_visit_id, word, sort = TRUE) %>%
    ungroup() %>%
    pivot_wider(names_from = word,
                values_from = n)

  #make sure there are 40 predictive words (+ pat_visit_id), else stop.
  #stop() is used because return() is not valid at the top level of a script.
  #The log sink is closed first so that, in an interactive session, later
  #console output is not still diverted into the log file.
  if (ncol(dtm_temp) != 41) {
    if (sink.number() > 0) sink()
    stop("DID NOT RETURN 40 PREDICTIVE WORDS AS FEATURES. increase number_of_words=150 in step2 block intro ")
  }

  #"function" is a reserved word in R; rename that feature (no-op when absent),
  #presumably so formula-based model code can use it as a column name
  names(dtm_temp)[names(dtm_temp) == "function"] <- "function_word"

  features <- left_join(note_length, dtm_temp, by = "pat_visit_id")

  #add back in the assigned label for the labeled set
  id <- labeled_set %>%
    select(pat_visit_id, assigned_label) %>%
    distinct()

  features <- left_join(id, features, by = "pat_visit_id")

  #give 0 to NA (word absent from the note) and divide the word columns by
  #word_count so each feature is the word's share of the note
  features[is.na(features)] <- 0
  # names(features)
  word_cols <- setdiff(names(features), c("pat_visit_id", "assigned_label", "word_count"))
  features[word_cols] <- features[word_cols] / features$word_count
  features$assigned_label <- as.factor(features$assigned_label)
  features_final <- as.data.frame(features)

###############
#Separate into training and test sets

  leave_out_sample_prob <- .2 #note: .1 is used for true data, but replication code based on much smaller fake data & thus needs more in test sample

  #draw ceiling(prob * n) of the given ids without replacement.
  #Indexing through sample.int() consumes the RNG exactly like sample(ids, size)
  #does for two or more ids, but avoids sample()'s special case where a single
  #numeric id x is treated as 1:x (which would pick an unrelated id).
  draw_test_ids <- function(ids, prob) {
    ids[sample.int(length(ids), size = ceiling(prob * length(ids)))]
  }

  #sample each label separately so the test set contains both positive and
  #negative labels (whenever both exist)
  ids_pos <- features_final$pat_visit_id[features_final$assigned_label == 1]
  ids_neg <- features_final$pat_visit_id[features_final$assigned_label == 0]

  test_id <- c(
    draw_test_ids(ids_pos, leave_out_sample_prob),
    draw_test_ids(ids_neg, leave_out_sample_prob)
  )

  in_test <- features_final$pat_visit_id %in% test_id
  features_final_test <- features_final[in_test, ]
  features_final_train <- features_final[!in_test, ]
  
###############
#Prepare for ML algorithm 

  ###Create a classification task for the training set and for the test set
  #makeClassifTask arguments:
  #  data=     the data to use
  #  target=   the target factor variable
  #  positive= the value to be considered positive
  #pat_visit_id is an identifier, not a feature, so it is dropped from both
  #tasks (all other columns remain features).
  #Each task is built directly from its data object (no pipe) so mlr's default
  #task id is still taken from that object's name.
  trainTask <- makeClassifTask(data = features_final_train, target = "assigned_label", positive = 1)
  trainTask <- dropFeatures(task = trainTask, "pat_visit_id")

  testTask <- makeClassifTask(data = features_final_test, target = "assigned_label", positive = 1)
  testTask <- dropFeatures(task = testTask, "pat_visit_id")

  ### Create the Learner - the algorithm to use
  #"classif.randomForest": random forest classifier
  #predict.type = "prob": return prob(1) instead of a 1/0 class
  #hyperparameters could be fixed with par.vals=list(...); here they are tuned below.
  #Important parameters (see getParamSet("classif.randomForest") for all):
  #  ntree:    number of trees to grow (default is 500)
  #  mtry:     how many variables to sample at each node split (default is sqrt(# of col))
  #  nodesize: how many observations wanted in terminal nodes (high nodesize -> short tree depth)
  rf <- makeLearner("classif.randomForest", predict.type = "prob")

  

###############
#Tune hyperparameters using the train set 

  ####tune
  #search ranges are based on the rules of thumb at https://bradleyboehmke.github.io/HOML/random-forest.html
  #  ntree:    start with 10*# of features and go from there
  #  mtry:     the rule of thumb suggests values across 2-p centered at sqrt(p)
  #            (p = number of features); the random search draws from 2-40
  #  nodesize: try values between 1 and 10
  rf_param <- makeParamSet(
    makeIntegerParam("ntree", lower = 300, upper = 500),
    makeIntegerParam("mtry", lower = 2, upper = 40),
    makeIntegerParam("nodesize", lower = 1, upper = 10)
  )

  #random search over grid_search_n hyper-parameter draws
  rancontrol <- makeTuneControlRandom(maxit = grid_search_n)

  #hold-out resampling: 2/3 of the training set fits each candidate and the
  #remaining 1/3 scores it (stratified by label)
  hout <- makeResampleDesc("Holdout", split = 2/3, stratify = TRUE)

  #tune the random forest (rf), minimizing the false discovery rate fdr = fp/(fp+tp)
 # parallelStart(mode="socket", cpu=10, level="mlr.tuneParams") #commented out for fake note data, but needed for true data
  rf_tune_all <- tuneParams(learner = rf, resampling = hout, task = trainTask,
                            par.set = rf_param, control = rancontrol, measures = fdr)
  #parallelStop() #commented out for fake note data, but needed for true data

  #output tuned parameters (and their tuning fdr) to the log file
  cat(" \n tuned parameter values \n ")
  tuned.params <- c(rf_tune_all$y, rf_tune_all$x)
  print(tuned.params)
  
###############
#Train model on training set using the tuned hyperparameters and apply to test set

  ###Use the determined hyperparameters to train and test the model
  rf.tree <- setHyperPars(rf, par.vals = rf_tune_all$x)
  ml_tuned_train <- train(rf.tree, trainTask)

  #show how it performs on the test set
  result_ml_tuned_test <- predict(ml_tuned_train, testTask)

  #print out the confusion matrix and FDR.
  #print() is explicit because top-level values are not auto-printed when the
  #script is run with source(); without it these results never reach the log.
  cat(" \n Confusion Matrix for Test Set (with tuned parameters)  \n ")
  print(calculateConfusionMatrix(result_ml_tuned_test))

  cat(" \n Performance Measures for Test Set (with tuned parameters) \n ")
  print(performance(
    result_ml_tuned_test,
    measures = list(fdr, ppv, tpr, fnr)
  ))

sink() 
#end log file. everything after is saved independently 


###############
#Apply the ML algorithm to the full data, including those with missing assigned_label  

  #Prepping the data: tokens of all visits (labeled and unlabeled)
  unlabled_all <- clean_data_token

  #length of note: number of tokens per visit
  note_length_un <- unlabled_all %>%
    count(pat_visit_id, name = "word_count")

  #document-feature matrix restricted to the selected top words
  #(column order is irrelevant here: columns are re-selected by name below)
  unlabled_all_feat <- unlabled_all[unlabled_all$word %in% top_words_2_all$top_word, ]

  dtm_temp_un_all <- unlabled_all_feat %>%
    count(pat_visit_id, word) %>%
    pivot_wider(names_from = word,
                values_from = n)

  #same "function" -> "function_word" renaming as for the training features
  #so that column names match (no-op when absent)
  names(dtm_temp_un_all)[names(dtm_temp_un_all) == "function"] <- "function_word"

  feature_temp_un_all <- left_join(note_length_un, dtm_temp_un_all, by = "pat_visit_id")

  #give 0 to NA and divide only the feature-word columns by word_count
  feature_temp_un_all[is.na(feature_temp_un_all)] <- 0
  word_cols_un <- setdiff(names(feature_temp_un_all), c("pat_visit_id", "word_count"))
  feature_temp_un_all[word_cols_un] <- feature_temp_un_all[word_cols_un] / feature_temp_un_all$word_count

  #use exactly the training columns minus the target, in the training order,
  #selected by name rather than by hard-coded positions
  model_cols <- setdiff(colnames(features_final_train), "assigned_label")
  feature_final_un_all <- as.data.frame(feature_temp_un_all) %>%
    select(all_of(model_cols))

  #####predict with trained & tuned ml algorithm
  result2_ml_all <- predict(ml_tuned_train, newdata = feature_final_un_all)
  result2_ml_all_df <- result2_ml_all$data

  ####
  #create dataframe to output: predicted probability of label 1 per visit
  #(assumes mlr returns predictions in the row order of newdata)
  temp_all <- data.frame(
    prob1 = result2_ml_all_df$prob.1,
    pat_visit_id = feature_final_un_all$pat_visit_id
  )

  #merge back in with the main data

  #crosswalk from pat_visit_id to (pat_id, visit_id)
  id_temp <- distinct(select(clean_data_token, pat_id, visit_id, pat_visit_id))

  #attach prob1 to every visit of the main data (visits without any tokens
  #get NA), drop the helper id and note text, and save
  ml_final <- note_dat %>%
    left_join(id_temp, by = c("pat_id", "visit_id")) %>%
    left_join(temp_all, by = "pat_visit_id") %>%
    select(-c(pat_visit_id, notes))

  write.csv(ml_final, file.path("..", "data", "intermediate", "note_dat_fake_wQ.csv"), row.names = FALSE)

  #clean up: keep only what the remaining two tasks need
  #  word cloud: top words and the tokens needed to count them
  #  validation figure: behav.val and prob1 (both in ml_final)
  keep_objects <- c("ml_final", "clean_dta", "note_dat", "q1_words", "q0_words",
                    "clean_data_token", "stem_to_word_dat")
  rm(list = setdiff(ls(envir = .GlobalEnv), keep_objects), envir = .GlobalEnv)
  
    
    
#################################################
#Step 3: Validation Figure based on behav.val 
#################################################
    
#behav.val as a labelled factor so the legend reads "behav.val=0"/"behav.val=1"
val_data <- ml_final
val_data$behav.val <- factor(val_data$behav.val, levels = c(0, 1),
                             labels = c("behav.val=0", "behav.val=1"))

#density of the ML-predicted probability, one curve per behav.val value
val_ml_plot <- ggplot(val_data, aes(x = prob1, fill = behav.val)) +
  geom_density(alpha = 0.6, color = "black") +
  scale_fill_manual(values = c("behav.val=0" = "orchid", "behav.val=1" = "seagreen")) +
  labs(title = "",
       x = "ML- Predicted Probability of Behavioral Assessment",
       y = "Density",
       fill = "") +
  theme_classic(base_size = 12) +
  theme(legend.position = "bottom",
        legend.title = element_text(size = 12),
        legend.text = element_text(size = 12))

#val_ml_plot

#save the figure
ggsave(file.path("..", "output", "figures", "fig_c1.png"),
       plot = val_ml_plot, width = 6, height = 5, dpi = 300)

rm(val_ml_plot, val_data)
    
    
#################################################
#Step 4: Word Cloud
#################################################

# stems selected as top words for each label
stems_q1 <- q1_words$top_word
stems_q0 <- q0_words$top_word

#count, within each label's tokens, how often each of that label's top stems
#occurs (only stems that actually occur are kept); the count itself, not a
#frequency, sizes the word cloud
freq_q1 <- clean_data_token %>%
  filter(assigned_label == 1, word %in% stems_q1) %>%
  count(word, name = "n_top")

freq_q0 <- clean_data_token %>%
  filter(assigned_label == 0, word %in% stems_q0) %>%
  count(word, name = "n_top")

#####
#aside- map each stem to its most frequent original (non-stemmed) word
#(one word per stem; ties resolved by slice_max with with_ties = FALSE)
stem_to_word_dat <- stem_to_word_dat %>%
  count(stem, word, sort = TRUE) %>%
  group_by(stem) %>%
  slice_max(n, n = 1, with_ties = F) %>%
  ungroup() %>%
  select(word = stem, word_to_replace = word)

#display the representative word instead of the stem
freq_q1 <- freq_q1 %>%
  left_join(stem_to_word_dat, by = "word") %>%
  select(n_top, word = word_to_replace)

freq_q0 <- freq_q0 %>%
  left_join(stem_to_word_dat, by = "word") %>%
  select(n_top, word = word_to_replace)




####################
#word clouds, one per label, sized by the top-word counts.
#Each cloud resets the seed to 12345 before drawing its random rotations
#(0 or 90 degrees) and colours, so both figures are reproducible.
#NOTE(review): `color` holds colour *names*, but no scale_color_identity() is
#added, so ggplot treats them as discrete levels and uses its default palette
#rather than the sampled colours. Confirm this is intended before changing it,
#since it alters fig_2a/fig_2b.
save_wordcloud <- function(freq_df, file_name) {
  df <- data.frame(word = freq_df$word, freq = freq_df$n_top)

  set.seed(12345)  # For reproducibility
  df <- df %>%
    mutate(
      angle = sample(c(0, 90), n(), replace = TRUE), # Random rotation
      color = sample(colors(), n(), replace = TRUE)  # Random colors
    )

  cloud <- ggplot(df, aes(label = word, size = freq, angle = angle, color = color)) +
    geom_text_wordcloud(shape = "square") +
    scale_size_area(max_size = 10) +
    theme_void() +
    theme(plot.margin = ggplot2::margin(0, 0, 0, 0))

  ggsave(file.path("..", "output", "figures", file_name),
         plot = cloud, width = 4, height = 3, dpi = 600)
  invisible(cloud)
}

#Q1 words -> fig_2b, Q0 words -> fig_2a
q1_word <- save_wordcloud(freq_q1, "fig_2b.png")
q0_word <- save_wordcloud(freq_q0, "fig_2a.png")


#END OF SCRIPT

   